import datasets
import jieba

SEPARATOR = "<<<SEP>>>"


DATASETS = ["hc3_chinese", "hc3_english"]


def load_hc3_chinese(cache_dir):
    """Load HC3-Chinese human answers between 101 and 149 jieba tokens.

    Parameters
    ----------
    cache_dir : str
        Directory used as the HuggingFace ``datasets`` cache.

    Returns
    -------
    list[str]
        The first human answer of each example whose jieba token count
        is strictly between 100 and 150.
    """
    dataset = datasets.load_dataset(
        "Hello-SimpleAI/HC3-Chinese",
        "all",
        split="train",
        cache_dir=cache_dir,
        trust_remote_code=True,
    )
    answers = dataset["human_answers"]
    # Keep the first human answer per example. Guard with `doc and ...`:
    # an example with an empty answer list would otherwise raise
    # IndexError on doc[0].
    docs = [doc[0] for doc in answers if doc and doc[0]]
    # Token length via jieba word segmentation, since Chinese text has
    # no whitespace-delimited words.
    return [doc for doc in docs if 100 < len(jieba.lcut(doc)) < 150]


def load_hc3_english(cache_dir):
    """Load HC3 (English) human answers between 101 and 149 words.

    Parameters
    ----------
    cache_dir : str
        Directory used as the HuggingFace ``datasets`` cache.

    Returns
    -------
    list[str]
        The first human answer of each example whose whitespace-split
        word count is strictly between 100 and 150.
    """
    dataset = datasets.load_dataset(
        "Hello-SimpleAI/HC3",
        "all",
        split="train",
        cache_dir=cache_dir,
        trust_remote_code=True,
    )
    answers = dataset["human_answers"]
    # Keep the first human answer per example. Guard with `doc and ...`:
    # an example with an empty answer list would otherwise raise
    # IndexError on doc[0].
    docs = [doc[0] for doc in answers if doc and doc[0]]
    # English length is approximated by whitespace-delimited word count.
    return [doc for doc in docs if 100 < len(doc.split()) < 150]


def load(name, cache_dir, **kwargs):
    """Dispatch to the ``load_<name>`` function for a registered dataset.

    Parameters
    ----------
    name : str
        One of the entries in ``DATASETS``.
    cache_dir : str
        Passed through to the loader as its cache directory.
    **kwargs
        Extra keyword arguments forwarded to the loader.

    Raises
    ------
    ValueError
        If ``name`` is not a registered dataset.
    """
    if name not in DATASETS:
        raise ValueError(f"Unknown dataset {name}")
    # Loaders follow the naming convention load_<dataset_name> at module
    # scope, so resolve the function dynamically.
    loader = globals()[f"load_{name}"]
    return loader(cache_dir=cache_dir, **kwargs)
